#!/usr/bin/env python
from __future__ import print_function
import pickle
import rbm
import numpy
import sys

# Python 2 compatibility: make ``range`` lazy like it is on Python 3.
# The check uses sys.version_info rather than comparing the free-form
# sys.version string, whose lexicographic ordering is not a version ordering
# (e.g. "10.0" sorts before "3").
if sys.version_info[0] < 3:
    range = xrange  # noqa: F821 -- xrange only exists on Python 2

def output_array(a):
    """Write the values of `a` to stdout as one space-separated line.

    Each value is followed by a single space (so the line ends with a
    trailing space when `a` is non-empty) and the line is terminated by a
    newline.  Values below 0.001 are written as "0", values above 0.999 as
    "1", and every other value as ``str(value)``.
    """
    def render(value):
        # Collapse values that are numerically (almost) 0 or 1 to a bare digit.
        if value < 0.001:
            return "0"
        if value > 0.999:
            return "1"
        return str(value)

    sys.stdout.write("".join(render(value) + " " for value in a) + "\n")

def _print_usage():
    """Print the command-line help text to stdout."""
    print("Usage: rbmcmd statefile command parameters...")
    print("    rbmcmd statefile init num_visible num_hidden")
    print("    rbmcmd statefile train max_epochs learning_rate < whitespace-separated lines of num_visible numbers 0 or 1.")
    print("    rbmcmd statefile run_visible < whitespace-separated lines of num_visible numbers > whitespace-separated lines of num_hidden numbers")
    print("    rbmcmd statefile run_hidden < whitespace-separated lines of num_hidden numbers > whitespace-separated lines of num_visible numbers")
    print("    rbmcmd statefile daydream_trace num_samples > num_samples whitespace-separted lines of num_visible numbers")
    print("    rbmcmd statefile daydream num_samples num_dreams > num_dreams whitespace-separated lines of num_visible numbers")


def _load_state(fname):
    """Return the object unpickled from `fname`, or None on an IOError."""
    try:
        with open(fname, "rb") as f:
            # NOTE(security): unpickling can execute arbitrary code, so the
            # state file must come from a trusted source.
            return pickle.load(f)
    except IOError:
        return None


def _read_rows(stream, limit):
    """Read at most `limit` lines from `stream` and parse them as float rows.

    Lines starting with "#" and lines containing only whitespace are skipped
    (they still count towards `limit`).

    Returns a ``(rows, at_eof)`` tuple: `rows` is a list of lists of floats
    and `at_eof` is True once the stream has been exhausted.
    """
    rows = []
    for _ in range(limit):
        line = stream.readline()
        if not line:
            return rows, True
        if line.startswith("#"):
            continue
        fields = line.split()
        if not fields:
            # A blank line would otherwise become an empty row and make the
            # resulting array ragged.
            continue
        # Build a real list: on Python 3 map() is a lazy iterator, which
        # numpy.array() would store as an opaque object instead of numbers.
        rows.append([float(field) for field in fields])
    return rows, False


def main():
    """Run the rbmcmd command-line interface using ``sys.argv``.

    ``sys.argv[1]`` is the state file and ``sys.argv[2]`` the command; the
    remaining arguments are command parameters (see the usage text).  The
    ``init`` command creates a new ``rbm.RBM``; every other command operates
    on the object unpickled from the state file.  ``train``, ``run_visible``
    and ``run_hidden`` consume stdin in chunks of up to 1000 lines.  After
    ``init`` and ``train`` the object is pickled back to the state file.

    Returns None.  Exits via ``sys.exit`` with an error message when a
    command lacks required parameters or the state file cannot be read.
    """
    if len(sys.argv) < 3 or sys.argv[1] == "--help":
        _print_usage()
        return
    fname = sys.argv[1]
    cmd = sys.argv[2]
    params = sys.argv[3:]

    # Number of parameters each command reads after the command name.
    param_counts = {
        "init": 2,
        "train": 2,
        "run_visible": 0,
        "run_hidden": 0,
        "daydream_trace": 1,
        "daydream": 2,
    }
    if cmd not in param_counts:
        print("Unknown command")
        return
    if len(params) < param_counts[cmd]:
        sys.exit("rbmcmd: '%s' expects %d parameter(s); run with --help for usage"
                 % (cmd, param_counts[cmd]))

    if cmd == "init":
        # The previous state is not loaded here, so a corrupt or empty state
        # file cannot prevent re-initialisation.
        r = rbm.RBM(num_visible=int(params[0]), num_hidden=int(params[1]))
        r.debug_print = False
    else:
        r = _load_state(fname)
        if r is None:
            sys.exit("rbmcmd: cannot read state file '%s'; run 'init' first"
                     % fname)

    if cmd in ("train", "run_visible", "run_hidden"):
        maxchunk = 1000
        if cmd == "train":
            # Parsed once up front instead of once per chunk.
            max_epochs = int(params[0])
            learning_rate = float(params[1])

        at_eof = False
        while not at_eof:
            rows, at_eof = _read_rows(sys.stdin, maxchunk)
            if not rows:
                # Empty input, or EOF landing exactly on a chunk boundary:
                # there is nothing to hand to the model.
                continue
            data = numpy.array(rows)
            if cmd == "train":
                r.train(data, max_epochs=max_epochs, learning_rate=learning_rate)
            elif cmd == "run_visible":
                for states in r.run_visible(data):
                    output_array(states)
            else:
                for states in r.run_hidden(data):
                    output_array(states)
    elif cmd == "daydream_trace":
        num_s = int(params[0])
        for row in r.daydream(num_s):
            # Raw values, each followed by a space, one row per line.
            sys.stdout.write("".join(str(value) + " " for value in row) + "\n")
    elif cmd == "daydream":
        num_s = int(params[0])
        num_d = int(params[1])
        for _ in range(num_d):
            # Only the last row returned by each daydream call is printed.
            output_array(r.daydream(num_s)[-1])

    if cmd in ("init", "train"):
        with open(fname, "wb") as f:
            pickle.dump(r, f)

# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()
